# run_ollama_ner.py
import json
import os

import ollama
import pandas as pd
from tqdm import tqdm

# --- Configuration ---
# CSV read by main(); it must provide a 'text' column.
INPUT_CSV = "data/model_outputs/results_with_conflibert.csv"
# CSV written by main(): the input rows plus one entity column per model.
OUTPUT_CSV = "data/model_outputs/ner_inference_results.csv"
# Ollama model tags queried for every row.
GEMMA_MODEL = "gemma2:9b-instruct-q4_K_M"
LLAMA_MODEL = "llama3.1:8b-instruct-q4_K_M"

# --- NER prompt template ---
# Used as the content of the single user message sent to each model.
# It is filled in with str.format(text=...), so "{text}" is the only
# placeholder and every literal JSON brace in the example is doubled
# ("{{" / "}}") to survive formatting.
# The prompt defines ten entity labels and asks for a JSON object of the form
# {"entities": [{"entity_text": ..., "entity_label": ...}, ...]}.
NER_PROMPT_TEMPLATE = """You are an expert in Named Entity Recognition (NER) for analyzing texts about political conflict and events.
Your task is to identify and extract all named entities from the user's text according to the provided entity definitions.

**Entity Definitions:**
- **Organisation:** A formal group of people with a particular purpose (e.g., "United Nations", "ISIL coalition").
- **Person:** A specific individual's name (e.g., "Carter", "Abadi").
- **Location:** A geographical place, such as a city, country, or region (e.g., "Geneva", "Iraq").
- **Weapon:** A specific type of weapon or military equipment mentioned (e.g., "Javelin missile", "car bomb").
- **Nationality:** An adjective describing a person or group's origin (e.g., "Ukrainian", "Islamist").
- **Temporal:** A phrase indicating a specific time or date (e.g., "next week", "July 21, 2016").
- **DocumentReference:** A reference to a specific document (e.g., "Resolution 242").
- **Money:** A specific monetary value (e.g., "$10 million").
- **Quantity:** A number and a unit that is not money (e.g., "50 kilograms", "43 people").
- **MilitaryPlatform:** A major military asset like a ship or aircraft (e.g., "HMS Ocean", "F-16 fighter jet").

**Output Instructions:**
Return a single JSON object with a single key "entities". The value of "entities" must be a list of JSON objects, where each object represents a single extracted entity. Each object must have two keys:
1. "entity_text": The exact text of the entity as it appears in the source text.
2. "entity_label": The corresponding label from the definitions provided. Do not use B- or I- prefixes.

**Example:**
Text: "The Taliban attacked Kabul with rockets last Tuesday."
Output:
{{
  "entities": [
    {{
      "entity_text": "The Taliban",
      "entity_label": "Organisation"
    }},
    {{
      "entity_text": "Kabul",
      "entity_label": "Location"
    }},
    {{
      "entity_text": "rockets",
      "entity_label": "Weapon"
    }},
    {{
      "entity_text": "last Tuesday",
      "entity_label": "Temporal"
    }}
  ]
}}

**User Text to Analyze:**
"{text}"
"""

def get_llm_entities(client, model_name, text):
    """Ask one Ollama model to extract named entities from ``text``.

    Args:
        client: Object exposing ``chat(model=..., messages=..., format=...)``;
            ``main`` passes an ``ollama.Client``.
        model_name: Ollama model tag to query.
        text: Passage to analyse; it is interpolated into
            ``NER_PROMPT_TEMPLATE``.

    Returns:
        A list of entity dicts taken from the ``"entities"`` key of the
        model's JSON reply. The result is always a list. It is empty when
        ``text`` is blank, when the request or JSON decoding fails, or when
        the reply does not hold a list under ``"entities"``. Items inside
        that list which are not JSON objects are dropped.
    """
    # Blank input cannot contain entities; skip the round trip to the model.
    if text is None or not str(text).strip():
        return []

    prompt = NER_PROMPT_TEMPLATE.format(text=text)
    try:
        response = client.chat(
            model=model_name,
            messages=[{'role': 'user', 'content': prompt}],
            format='json',
        )
        data = json.loads(response['message']['content'])
    except Exception as e:
        # Deliberately broad: this is the best-effort boundary of a long
        # batch run, so one failed request or undecodable reply must not
        # abort the remaining rows.
        print(f"Error on model {model_name}: {e}")
        return []

    # format='json' only guarantees syntactically valid JSON, not the shape
    # requested in the prompt, so validate before handing it to callers.
    entities = data.get('entities', []) if isinstance(data, dict) else None
    if not isinstance(entities, list):
        print(f"Error on model {model_name}: unexpected response shape ({type(data).__name__})")
        return []
    return [item for item in entities if isinstance(item, dict)]

def main():
    """Run both Ollama models over every row of INPUT_CSV and save the results.

    Reads INPUT_CSV, sends each value of its 'text' column to GEMMA_MODEL and
    LLAMA_MODEL through ``get_llm_entities``, stores each model's entity list
    as a JSON string in the 'gemma_entities' / 'llama_entities' columns, and
    writes the augmented table to OUTPUT_CSV.

    Raises:
        KeyError: If INPUT_CSV has no 'text' column.
    """
    print("Initializing Ollama client...")
    client = ollama.Client()
    df = pd.read_csv(INPUT_CSV)
    if 'text' not in df.columns:
        raise KeyError(f"'text' column not found in {INPUT_CSV}; columns are {list(df.columns)}")

    # Create the output directory up front so a missing folder fails (or is
    # fixed) now rather than after the whole inference run has finished.
    output_dir = os.path.dirname(OUTPUT_CSV)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    no_entities = json.dumps([])
    # One full pass per model instead of alternating models on every row:
    # if both models do not fit in memory together, alternating can force
    # the server to swap models on each request.
    for column, model_name in (('gemma_entities', GEMMA_MODEL), ('llama_entities', LLAMA_MODEL)):
        predictions = []
        for text in tqdm(df['text'], desc=f"Running Ollama Inference ({model_name})"):
            # Missing cells are read as NaN; str(NaN) would send the literal
            # "nan" to the model, so record an empty result instead.
            if pd.isna(text) or not str(text).strip():
                predictions.append(no_entities)
                continue
            predictions.append(json.dumps(get_llm_entities(client, model_name, str(text))))
        df[column] = predictions

    df.to_csv(OUTPUT_CSV, index=False)
    print(f"\n✅ Ollama inference complete. Final results saved to {OUTPUT_CSV}")

# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()